#!/usr/bin/python

import sys,os,re,glob,math,glob,signal,traceback
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
import random as pyrandom
from optparse import OptionParser
from pylab import *
from scipy.ndimage import interpolation,filters
import ocrolib
from ocrolib import docproc,Record,mlp,dbhelper
from scipy import stats
signal.signal(signal.SIGINT,lambda *args:sys.exit(1))

# Command-line interface.  The usage string below is the text optparse
# prints for --help (and that parser.print_help() emits).
parser = OptionParser(usage="""
%prog [options] -o output.cmodel input.db ...

Trains models based on a cluster database.

For faster speed and better memory usage, use the "-b" option, which buffers
samples in a 1bpp buffer (only binary input patterns); however, this only works
with binary inputs and feature extractors that generate (approximately) binary
data.  For example, it works for binary character images and ScaledExtractor, but not
if you use grayscale character images or StandardExtractor.

If you have lots of training data, try "-E ScaledFeatureExtractor -b", for handwriting
recognition, "-E StandardExtractor" is a better choice.

You can choose different kinds of feature extractors with the -E flag.  
Some possible values are: ScaledExtractor (raw grayscale image rescaled to a target size) 
and BiggestCcExtractor (biggest connected component only, otherwise treated like scaledfe).
You can find additional components by running "ocropus components" and looking for
implementors of IExtractor.

Common parameters for the model are:

-m '=ocrolib.mlp.AutoMlpModel()'
""")

# general options
# NOTE(review): options.bits, options.extractor, options.rejectonly,
# options.lengthpred and options.multionly are declared here but are not
# read anywhere in the visible script -- confirm whether they are consumed
# elsewhere or are leftovers.
parser.add_option("-o","--output",help="output model name",default=None)
parser.add_option("-m","--model",help="IModel name",default="ocrolib.mlp.AutoMlpModel()")
parser.add_option("-K","--nrounds",help="number of training rounds",type=int,default=48)
parser.add_option("-b","--bits",help="buffer training data with 1 bpp",action="store_true")
parser.add_option("-t","--table",help="database table to use for training",default="chars")
parser.add_option("-u","--unlabeled",help="treat unlabeled ('_') as reject",action="store_true")
parser.add_option("-v","--verbose",help="verbose",action="store_true")
parser.add_option("-E","--extractor",help="feature extractor",default="StandardExtractor")
# -N/-D control the synthetic variants produced with distort() for each sample
parser.add_option("-N","--nvariants",help="number of variants to generate",default=0,type="int")
parser.add_option("-D","--distortion",help="maximum distortion",default=0.2,type="float")
parser.add_option("-d","--display",help="debug display",action="store_true")
# the same value is applied to both mlp.maxthreads and mlp.maxthreads_train
parser.add_option("-Q","--maxthreads",help="max # of threads for training",type=int,default=4)
# different training modalities
parser.add_option("-g","--nogeometry",help="do not add geometric information",action="store_true")
parser.add_option("-r","--noreject",help="disable reject class",action="store_true")
parser.add_option("-1","--single",help="train only single chars, treat multiple as reject (combine with -r if you like)",action="store_true")
parser.add_option("-R","--rejectonly",help="only train reject/non-reject classifier",action="store_true")
parser.add_option("-L","--lengthpred",help="train a length predictor",action="store_true")
parser.add_option("-M","--multionly",help="only train multiple charactes",action="store_true")
# the following are for limiting training samples
parser.add_option("-n","--limit",help="limit total training to n samples",default=999999999,type="int")
# -T stays a string: the loading loop interprets it either as a percentile
# of the per-class sample counts or as a class label whose count is used
parser.add_option("-T","--threshold",help="threshold for estimating ntrain (either a percentile or a class)",default="70")
# -F scales the per-class sample quota for the reject class "~"
parser.add_option("-F","--rejectfactor",help="multiple of per class ntrain",type=float,default=20.0)
# parse sys.argv: `options` holds the flag values, `args` the positional
# arguments (used below as the paths of the character databases)
(options,args) = parser.parse_args()

# With no input databases there is nothing to train on: show the help text
# and stop.  This is checked first so that a bare invocation prints the help
# instead of failing on the missing -o flag.
if len(args)<1:
    parser.print_help()
    sys.exit(0)

# -o is mandatory.  This is user-input validation, so report it through
# optparse (usage message on stderr, exit status 2) rather than with an
# `assert`, which is removed when Python runs with -O.
if options.output is None:
    parser.error("you must provide a -o flag")

# use the same thread limit for classification and for training
mlp.maxthreads.value = options.maxthreads
mlp.maxthreads_train.value = options.maxthreads

# some utility functions we're going to need later

def pad(image,dx,dy,bgval=None):
    """Pad a 2D image with a constant border and return the new array.

    Despite the names, `dx` is the number of rows added above and below
    the image (axis 0) and `dy` the number of columns added to its left
    and right (axis 1).  Either amount may be 0.
    If bgval is not given, the minimum of the input image is used.
    The result is a new float array of shape (h+2*dx, w+2*dy)."""
    if bgval is None: bgval = amin(image)
    h,w = image.shape
    result = zeros((h+2*dx,w+2*dy))
    result[:,:] = bgval
    # Use explicit end indices: a slice such as `dx:-dx` is empty when the
    # padding is 0, which made the assignment fail for unpadded axes.
    result[dx:dx+h,dy:dy+w] = image
    return result

def distort(image,sx,sy,sigma=10.0):
    """Return a randomly warped copy of a 2D image.

    Two fields of uniform noise are smoothed with a Gaussian of width
    `sigma` and rescaled so that the largest displacement is sy*height
    rows and sx*width columns (sx and sy are fractions of the image
    size).  The image is padded by slightly more than those amounts and
    then resampled at the displaced coordinates with linear
    interpolation, so the result has the padded shape, not the input
    shape.  Image should be an narray."""
    rows0,cols0 = image.shape
    max_dy = sy*rows0
    max_dx = sx*cols0
    padded = pad(image,int(max_dy+1),int(max_dx+1))
    rows,cols = padded.shape

    # vertical displacement field (drawn first), scaled to peak at max_dy
    noise_y = filters.gaussian_filter(rand(rows,cols)-0.5,sigma)
    noise_y = noise_y*(max_dy/amax(abs(noise_y)))
    # horizontal displacement field, scaled to peak at max_dx
    noise_x = filters.gaussian_filter(rand(rows,cols)-0.5,sigma)
    noise_x = noise_x*(max_dx/amax(abs(noise_x)))

    # turn the displacements into absolute sampling coordinates
    coords_y = noise_y+arange(rows)[:,newaxis]
    coords_x = noise_x+arange(cols)[newaxis,:]
    return interpolation.map_coordinates(padded,array([coords_y,coords_x]),order=1)

# unlabeled-as-reject implies that we train reject classes

if options.unlabeled: options.noreject = False

if len(args)<1:
    print "must specify at least one character database as argument"
    sys.exit(1)

output = options.output

if output is None:
    output= os.path.splitext(args[0])[0]+".cmodel"
    print "output to",output

if os.path.exists(output):
    print output,"already exists"
    sys.exit(1)

if not ".cmodel" in output and not ".pymodel" in output:
    print "output",output,"should end in .cmodel or .pymodel"
    sys.exit(1)

# create a Python window if requested

if options.display:
    ion()
    show()

# initialize the C++ debugging graphics (just in case it's needed for
# debugging the native code feature extractors)

classifier = ocrolib.make_IModel(options.model)
print "classifier",classifier

# little utility function for displaying character training progress window

# number of characters shown so far; selects the grid cell used by show_char
fig = 0

def show_char(image,cls):
    """Draw one training character into the next cell of a 4x4 grid.

    The class label is overlaid in red.  Once all cells have been used
    the figure is cleared and filling starts again at the first cell.
    Meant as a progress display while training characters are loaded."""
    global fig
    side = 4
    slot = fig%side**2
    if slot==0:
        # first call, or the grid is full: start a fresh page
        clf()
        gray()
    subplot(side,side,slot+1)
    fig += 1
    imshow(image/255.0)
    text(3,3,str(cls),color="red",fontsize=14)
    draw()
    # wait at most 0.1s for a mouse click; presumably this is here to give
    # the GUI a chance to refresh the window
    ginput(1,timeout=0.1)

print "training..."

count = 0
nchar = 0
skip0 = 0
skip3 = 0
nreject = 0

def iterate_db(arg):
    print "===",arg,"==="

class Counter(dict):
    """A dict for tallying: looking up a missing key stores 0 under that
    key and returns it, so `c[key] += 1` works without initialisation.
    Only subscript access is affected; `get` and `in` behave as in dict."""
    def __getitem__(self,index):
        if index not in self:
            dict.__setitem__(self,index,0)
        return dict.__getitem__(self,index)

totals = Counter()
actual = Counter()

for arg in args:
    if not os.path.exists(arg):
        print ""+arg+": not found"
        sys.exit(1)

for arg in args:
    print "===",arg
    db = dbhelper.chardb(arg,options.table)
    print "total",db.execute("select count(*) from chars").next()[0]

    print "# determining per-class cutoff"

    classes = [tuple(x) for x in db.execute("select cls,count(*) from chars group by cls")]
    assert len(classes)>=2,"too few classes in database; got %s"%classes
    counts= array([x[1] for x in classes],'f')
    threshold = options.threshold
    if re.match(r'[0-9][0-9.]+',threshold):
        threshold = float(threshold)
        assert len(counts)>1
        assert threshold>=0 and threshold<=100
        ntrain = int(stats.scoreatpercentile(counts,threshold))
    else:
        for k,v in classes:
            if k==threshold:
                ntrain = v
                break
        if ntrain is None:
            print "class not found:",threshold
    print "classes",len(counts),"stats",amin(counts),median(counts),mean(counts),amax(counts),"ntrain",ntrain

    print "# sampling the classes"

    all_ids = []
    for k,v in classes:
        ids = list(db.execute("select id from chars where cls=?",[k]))
        ids = [x[0] for x in ids]
        if k=="" or ord(k[0])<33: continue
        if k=="_": continue
        if k=="~": ntrain = int(ntrain*options.rejectfactor)
        if len(ids)>ntrain: ids = pyrandom.sample(ids,ntrain)
        print " %s:%d"%(k,len(ids)),; sys.stdout.flush()
        all_ids += ids
    print

    counter = Counter()

    print "# training",len(all_ids)

    for id in all_ids:
        c = db.execute("select image,cls,rel from chars where id=?",[id]).next()
        cluster = Record(image=dbhelper.blob2image(c.image),cls=c.cls,rel=c.rel)

        # check whether we already have enough characters
        if count>options.limit: 
            break

        cls = cluster.cls
        if cls is None: 
            cls = "_"

        counter[cls] += 1
        totals[cls] += 1

        # empty strings have no class at all, so we skip them
        if len(cls)==0:
            skip0 += 1
            continue
        # can't train transcriptions longer than three characters
        if len(cls)>3:
            skip3 += 1
            continue
        # if the user requested only single character training, skip multi-character transcriptions
        if options.single and len(cls)>1: continue
        # if the user requested it, treat unlabeled samples ("_") as reject classes
        if options.unlabeled and cls=="_": cls = "~"
        # if the user didn't want reject training, skip all reject classes
        if options.noreject and cls=="~": continue
        # skip any remaining unlabeled samples
        if cls=="_": continue

        if cls=="~": nreject += 1

        # count the actually trained samples
        actual[cls] += 1

        geometry=None
        if not options.nogeometry:
            if cluster.rel is None or len(cluster.rel)<2:
                print "training with geometry requested, but",arg,"lacks geometry information"
                print "use the -g flag to turn off training with geometry"
                sys.exit(1)
            geometry = docproc.rel_geo_normalize(cluster.rel)

        # add the image to the classifier
        image = cluster.image/255.0
        if options.display and nchar%1000==0: show_char(image,cls)
        if min(image.shape)<3: continue
        if amax(image)==amin(image): continue
        if geometry is not None:
            classifier.cadd(image,cls,geometry=geometry)
        else:
            classifier.cadd(image,cls)
        count += 1

        # if the user requested automatically generated variants,
        # generate them and add them to the classifier as well
        for i in range(options.nvariants):
            if count>options.limit: break
            sx = options.distortion
            sy = options.distortion
            distorted = distort(image,sx,sy)
            if options.display and nchar%1000==0: show_char(distorted,cls)
            try:
                if geometry is not None:
                    classifier.cadd(image,cls,geometry=geometry)
                else:
                    classifier.cadd(image,cls)
                count += 1
            except:
                traceback.print_exc()
                continue
        nchar += 1
        if nchar%10000==0: print "training",nchar,count

print
print "summary statistics:"
print sorted(list(actual.items()))
print
print "training",count,"variants representing",nchar,"training characters"
if skip0>0: print "skipped",skip0,"zero-length transcriptions"
if skip3>0: print "skipped",skip3,"transcriptions that were more than three characters long"
print

# provide some feedback about the training process

if options.verbose:
    if "info" in dir(classifier): 
        classifier.info()
    if "getExtractor" in dir(classifiers) and "info" in dir(classifier.getExtractor()):
        classifier.getExtractor().info()

# now perform the actual training
# (usually, this is AutoMLP, a multi-threaded, long-running stochastic
# gradient descent training for a multi-layer perceptron)
    
print "starting classifier training"
round = 0
for progress in classifier.updateModel1(verbose=1):
    print "ocropus-ctrain:",round,progress
    modelbase,_ = os.path.splitext(options.output)
    ocrolib.save_component("%s.%03d.cmodel"%(modelbase,round),classifier)
    round += 1
    if round>=options.nrounds: break

# provide some feedback about the training process

if options.verbose:
    classifier.info()

# save the resulting model

print "saving",output
ocrolib.save_component(output,classifier)

